import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped



def update_episode_metrics():
    """
    Placeholder hook for per-episode metric aggregation.

    It is not implemented yet: the function takes no arguments, does nothing,
    and returns ``None``.
    """
    return None


def update_current_metrics(current_metrics, additional_metrics, rewards, infos):
    """
    Accumulate one environment step into the running per-episode metrics.

    :param current_metrics: Mapping from metric name to a numpy array holding one
        running value per sub-environment. It is updated in place.
    :param additional_metrics: Names of extra metrics read from the step infos.
        A name containing ``"max"`` is tracked as a running maximum of the info
        entry whose key is the name with ``"_max"`` removed
        (e.g. ``"distance_inf_n_max"`` tracks the maximum of ``info["distance_inf_n"]``).
        Any other name is summed from ``info[name]``.
    :param rewards: Per-sub-environment rewards of this step; added to ``"rewards"``.
    :param infos: Per-sub-environment info dicts of this step. Entry ``i`` only
        updates slot ``i`` of each metric array, so multiple sub-environments are
        accumulated independently.
    :return: ``current_metrics`` (the same dict, mutated in place).
    """
    current_metrics["rewards"] += rewards
    current_metrics["lengths"] += 1
    for metric in additional_metrics:
        if "max" in metric:
            # "<name>_max" keeps the running maximum of info["<name>"]; the
            # maximum includes the slot's current value (0 after a reset).
            source_key = metric.replace("_max", "")
            for i, info in enumerate(infos):
                current_metrics[metric][i] = np.maximum(current_metrics[metric][i], info[source_key])
        else:
            for i, info in enumerate(infos):
                current_metrics[metric][i] += info[metric]

    return current_metrics


def evaluate_policy(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float], Dict[str, List[float]]]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns average reward.

    Besides episode reward (``"rewards"``) and length (``"lengths"``), every metric
    name listed in the wrapped environment's ``additional_metrics`` attribute is
    tracked per episode from the step ``info`` dicts (see ``update_current_metrics``
    for how each one is accumulated). When ``"distance_inf_n"`` or
    ``"distance_inf_to_true_n"`` is among them, the derived ``*_max`` metrics for
    the inf and l2 distances are tracked as well. The environment's own
    ``additional_metrics`` attribute is never modified.

    Only a single environment is currently supported: a plain ``gym.Env`` is
    wrapped in a one-element ``DummyVecEnv``, and a ``VecEnv`` must contain
    exactly one sub-environment.

    If the environment exposes ``other_is_deterministic`` (a second agent acting
    inside the environment), it is switched to ``True`` for the duration of the
    evaluation and restored to its previous value afterwards, even if evaluation
    raises. This requires ``deterministic=True``.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment (one sub-environment).
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, the per-episode values of every tracked
        metric are returned instead of the mean/std of the reward.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: Mean reward per episode, std of reward per episode.
        When ``return_episode_rewards`` is True, returns a dict mapping each tracked
        metric name (``"rewards"``, ``"lengths"`` and any additional metrics) to the
        list of its per-episode values.
    """
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        env = DummyVecEnv([lambda: env])

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    # The metric bookkeeping below only supports a single sub-environment.
    assert len(env.envs) == 1

    n_envs = env.num_envs

    basic_metrics = ["rewards", "lengths"]

    # Work on a copy: extending the environment's own list would make it grow
    # with duplicate "_max" entries on every call (and fails for tuples).
    additional_metrics = list(getattr(env.envs[0], "additional_metrics", []))
    derived_metrics = []
    if "distance_inf_n" in additional_metrics:
        derived_metrics += ["distance_inf_n_max", "distance_l2_n_max"]
    if "distance_inf_to_true_n" in additional_metrics:
        derived_metrics += ["distance_inf_to_true_n_max", "distance_l2_to_true_n_max"]
    additional_metrics += [metric for metric in derived_metrics if metric not in additional_metrics]

    metrics = basic_metrics + additional_metrics

    episode_metrics = {metric: [] for metric in metrics}
    current_metrics = {metric: np.zeros(n_envs) for metric in metrics}

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    # Put a potential other agent in the environment into deterministic mode too.
    # A separate flag is used so that a previous value of None is also restored.
    has_other_agent = hasattr(env.envs[0], "other_is_deterministic")
    if has_other_agent:
        other_is_deterministic_before = env.envs[0].other_is_deterministic
        assert deterministic
        for _env in env.envs:
            _env.other_is_deterministic = True
    else:
        other_is_deterministic_before = None

    try:
        observations = env.reset()
        states = None
        episode_starts = np.ones((env.num_envs,), dtype=bool)
        while (episode_counts < episode_count_targets).any():
            actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
            observations, rewards, dones, infos = env.step(actions)

            current_metrics = update_current_metrics(current_metrics, additional_metrics, rewards, infos)

            for i in range(n_envs):
                if episode_counts[i] < episode_count_targets[i]:

                    # unpack values so that the callback can access the local variables
                    reward = rewards[i]
                    done = dones[i]
                    info = infos[i]
                    episode_starts[i] = done

                    if callback is not None:
                        callback(locals(), globals())

                    if dones[i]:
                        # NOTE: episodes are delimited by ``done`` and their statistics
                        # come from the accumulation above, not from Monitor's
                        # info["episode"]; a wrapper that signals ``done`` before the
                        # true end of an episode (e.g. Atari life loss) splits episodes.
                        for metric in episode_metrics:
                            episode_metrics[metric].append(current_metrics[metric][i])
                        for metric in current_metrics:
                            current_metrics[metric][i] = 0

                        episode_counts[i] += 1

            if render:
                env.render()
    finally:
        # Set the other agent back to its previous mode, even on error.
        if has_other_agent:
            for _env in env.envs:
                _env.other_is_deterministic = other_is_deterministic_before

    mean_reward = np.mean(episode_metrics["rewards"])
    std_reward = np.std(episode_metrics["rewards"])

    if reward_threshold is not None:
        assert mean_reward > reward_threshold, f"Mean reward below threshold: {mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_metrics

    print(f"mean_reward: {mean_reward}, std_reward: {std_reward}")

    return mean_reward, std_reward
